{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# WebDataset + Distributed PyTorch Training\n",
    "\n",
    "This notebook illustrates how to use the Web Indexed Dataset (`wids`) library for distributed PyTorch training using `DistributedDataParallel`.\n",
    "\n",
    "Using `webdataset` results in training code that is almost identical to plain PyTorch except for the dataset creation.\n",
    "Since `WebDataset` is an iterable dataset, you need to account for that when creating the `DataLoader`. Furthermore, for\n",
    "distributed training, easy restarts, etc., it is convenient to use a resampled dataset; this is in contrast to\n",
    "sampling without replacement for each epoch as used more commonly for small, local training. (If you want to use\n",
    "sampling without replacement with webdataset format datasets, see the companion `wids`-based training notebooks.)\n",
    "\n",
    "Training with `WebDataset` can be carried out completely without local storage; this is the usual setup in the cloud\n",
    "and on high speed compute clusters. When running locally on a desktop, you may want to cache the data, and for that,\n",
    "you set a `cache_dir` directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.distributed as dist\n",
    "from torch.nn.parallel import DistributedDataParallel\n",
    "from torchvision.models import resnet50\n",
    "from torchvision import transforms\n",
    "import ray\n",
    "import webdataset as wds\n",
    "import dataclasses\n",
    "import time\n",
    "from collections import deque\n",
    "from typing import Optional\n",
    "\n",
    "\n",
    "def enumerate_report(seq, delta, growth=1.0):\n",
    "    last = 0\n",
    "    count = 0\n",
    "    for count, item in enumerate(seq):\n",
    "        now = time.time()\n",
    "        if now - last > delta:\n",
    "            last = now\n",
    "            yield count, item, True\n",
    "        else:\n",
    "            yield count, item, False\n",
    "        delta *= growth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "epochs = 10\n",
    "maxsteps = int(1e12)\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Loading for Distributed Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets are just collections of shards in the cloud. We usually specify\n",
    "# them using {lo..hi} brace notation (there is also a YAML spec for more complex\n",
    "# datasets).\n",
    "\n",
    "bucket = \"https://storage.googleapis.com/webdataset/fake-imagenet\"\n",
    "trainset_url = bucket + \"/imagenet-train-{000000..001281}.tar\"\n",
    "valset_url = bucket + \"/imagenet-val-{000000..000049}.tar\"\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If running in the cloud or with a fast network storage system, we don't\n",
    "# need any local storage.\n",
    "\n",
    "if \"google.colab\" in sys.modules:\n",
    "    cache_dir = None\n",
    "    print(\"running on colab, streaming data directly from storage\")\n",
    "else:\n",
    "    cache_dir = \"./_cache\"\n",
    "    print(f\"not running in colab, caching data locally in {cache_dir}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The dataloader pipeline is a fairly typical `IterableDataset` pipeline\n",
    "# for PyTorch\n",
    "\n",
    "\n",
    "def make_dataloader_train():\n",
    "    \"\"\"Create a DataLoader for training on the ImageNet dataset using WebDataset.\"\"\"\n",
    "\n",
    "    transform = transforms.Compose(\n",
    "        [\n",
    "            transforms.RandomResizedCrop(224),\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.ToTensor(),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    def make_sample(sample):\n",
    "        return transform(sample[\"jpg\"]), sample[\"cls\"]\n",
    "\n",
    "    # This is the basic WebDataset definition: it starts with a URL and add shuffling,\n",
    "    # decoding, and augmentation. Note `resampled=True`; this is essential for\n",
    "    # distributed training to work correctly.\n",
    "    trainset = wds.WebDataset(\n",
    "        trainset_url,\n",
    "        resampled=True,\n",
    "        shardshuffle=True,\n",
    "        cache_dir=cache_dir,\n",
    "        nodesplitter=wds.split_by_node,\n",
    "    )\n",
    "    trainset = trainset.shuffle(1000).decode(\"pil\").map(make_sample)\n",
    "\n",
    "    # For IterableDataset objects, the batching needs to happen in the dataset.\n",
    "    trainset = trainset.batched(64)\n",
    "    trainloader = wds.WebLoader(trainset, batch_size=None, num_workers=4)\n",
    "\n",
    "    # We unbatch, shuffle, and rebatch to mix samples from different workers.\n",
    "    trainloader = trainloader.unbatched().shuffle(1000).batched(batch_size)\n",
    "\n",
    "    # A resampled dataset is infinite size, but we can recreate a fixed epoch length.\n",
    "    trainloader = trainloader.with_epoch(1282 * 100 // 64)\n",
    "\n",
    "    return trainloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's try it out\n",
    "\n",
    "\n",
    "def make_dataloader(split=\"train\"):\n",
    "    \"\"\"Make a dataloader for training or validation.\"\"\"\n",
    "    if split == \"train\":\n",
    "        return make_dataloader_train()\n",
    "    elif split == \"val\":\n",
    "        return make_dataloader_val()  # not implemented for this notebook\n",
    "    else:\n",
    "        raise ValueError(f\"unknown split {split}\")\n",
    "\n",
    "\n",
    "# Try it out.\n",
    "os.environ[\"GOPEN_VERBOSE\"] = \"1\"\n",
    "sample = next(iter(make_dataloader()))\n",
    "print(sample[0].shape, sample[1].shape)\n",
    "os.environ[\"GOPEN_VERBOSE\"] = \"0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Standard PyTorch Training\n",
    "\n",
    "This is completely standard PyTorch training; nothing changes by using WebDataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We gather all the configuration info into a single typed dataclass.\n",
    "\n",
    "\n",
    "@dataclasses.dataclass\n",
    "class Config:\n",
    "    epochs: int = 1\n",
    "    max_steps: int = int(1e18)\n",
    "    lr: float = 0.001\n",
    "    momentum: float = 0.9\n",
    "    rank: Optional[int] = None\n",
    "    world_size: int = 2\n",
    "    backend: str = \"nccl\"\n",
    "    master_addr: str = \"localhost\"\n",
    "    master_port: str = \"12355\"\n",
    "    report_s: float = 15.0\n",
    "    report_growth: float = 1.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(config):\n",
    "    # Define the model, loss function, and optimizer\n",
    "    model = resnet50(pretrained=False).cuda()\n",
    "    if config.rank is not None:\n",
    "        model = DistributedDataParallel(model)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=config.lr)\n",
    "\n",
    "    # Data loading code\n",
    "    trainloader = make_dataloader(split=\"train\")\n",
    "\n",
    "    losses, accuracies, steps = deque(maxlen=100), deque(maxlen=100), 0\n",
    "\n",
    "    # Training loop\n",
    "    for epoch in range(config.epochs):\n",
    "        for i, data, verbose in enumerate_report(trainloader, config.report_s):\n",
    "            inputs, labels = data[0].cuda(), data[1].cuda()\n",
    "\n",
    "            # zero the parameter gradients\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # forward + backward + optimize\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            # update statistics\n",
    "            loss = loss_fn(outputs, labels)\n",
    "            accuracy = (outputs.argmax(1) == labels).float().mean()  # calculate accuracy\n",
    "            losses.append(loss.item())\n",
    "            accuracies.append(accuracy.item())\n",
    "\n",
    "            if verbose and len(losses) > 0:\n",
    "                avgloss = sum(losses) / len(losses)\n",
    "                avgaccuracy = sum(accuracies) / len(accuracies)\n",
    "                print(\n",
    "                    f\"rank {config.rank} epoch {epoch:5d}/{i:9d} loss {avgloss:8.3f} acc {avgaccuracy:8.3f} {steps:9d}\",\n",
    "                    file=sys.stderr,\n",
    "                )\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            steps += len(labels)\n",
    "            if steps > config.max_steps:\n",
    "                print(\n",
    "                    \"finished training (max_steps)\",\n",
    "                    steps,\n",
    "                    config.max_steps,\n",
    "                    file=sys.stderr,\n",
    "                )\n",
    "                return\n",
    "\n",
    "    print(\"finished Training\", steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A quick smoke test of the training function.\n",
    "\n",
    "config = Config()\n",
    "config.epochs = 1\n",
    "config.max_steps = 1000\n",
    "train(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setting up Distributed Training with Ray\n",
    "\n",
    "Ray is a convenient distributed computing framework. We are using it here to start up the training\n",
    "jobs on multiple GPUs. You can use `torch.distributed.launch` or other such tools as well with the above\n",
    "code. Ray has the advantage that it is runtime environment independent; you set up your Ray cluster\n",
    "in whatever way works for your environment, and afterwards, this code will run in it without change."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@ray.remote(num_gpus=1)\n",
    "def train_on_ray(rank, config):\n",
    "    \"\"\"Set up distributed torch env and train the model on this node.\"\"\"\n",
    "    # Set up distributed PyTorch.\n",
    "    if rank is not None:\n",
    "        os.environ[\"MASTER_ADDR\"] = config.master_addr\n",
    "        os.environ[\"MASTER_PORT\"] = config.master_port\n",
    "        dist.init_process_group(backend=config.backend, rank=rank, world_size=config.world_size)\n",
    "        config.rank = rank\n",
    "        # Ray will automatically set CUDA_VISIBLE_DEVICES for each task.\n",
    "    train(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not ray.is_initialized():\n",
    "    ray.init()\n",
    "\n",
    "ray.available_resources()[\"GPU\"]\n",
    "\n",
    "\n",
    "def distributed_training(config):\n",
    "    \"\"\"Perform distributed training with the given config.\"\"\"\n",
    "    num_gpus = ray.available_resources()[\"GPU\"]\n",
    "    config.world_size = min(config.world_size, num_gpus)\n",
    "    results = ray.get([train_on_ray.remote(i, config) for i in range(config.world_size)])\n",
    "    print(results)\n",
    "\n",
    "\n",
    "config = Config()\n",
    "config.epochs = epochs\n",
    "config.max_steps = max_steps\n",
    "config.batch_size = batch_size\n",
    "print(config)\n",
    "distributed_training(config)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
